{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "32c165e0-7c5d-4285-b6d0-6aa1d665665d",
   "metadata": {},
   "source": [
    "## convert xml to [yolov5-rotation](https://github.com/acai66/yolov5_rotation) format. [旋转版yolov5](https://github.com/acai66/yolov5_rotation)标签格式转换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3180461b-8dae-46eb-9886-d95ff517e6d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import xml.etree.ElementTree as ET\n",
    "from tqdm import tqdm # pip install tqdm\n",
    "import os\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e4b10907-e99c-4e24-93ad-ce5c65f130bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "workdir = './data_wire3/' # datasets root path. 数据集路径\n",
    "images_dir = os.path.join(workdir, 'all_images') # images path. 图像路径\n",
    "labels_dir = os.path.join(workdir, 'all_labels') # xml labels path. xml标签路径\n",
    "# yolov5_all_images = os.path.join(workdir, 'yolov5_all_images') # all images for yolov5 rotation. 转换后旋转版yolov5可用的图像路径\n",
    "yolov5_all_labels = os.path.join(workdir, 'yolov5_all_labels') # all labels for yolov5 rotation. 转换后旋转版yolov5可用的txt标签路径\n",
    "for d in [yolov5_all_labels]:\n",
    "    if not os.path.exists(d):\n",
    "        os.mkdir(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6c4b69fa-26cc-4bf8-b53d-d4ad645bcb10",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_files = [i for i in os.listdir(labels_dir) if i[-4:] == '.xml']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d99e0116-ae7e-4bd2-adac-8040f60e3f46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "labels count:  311\n"
     ]
    }
   ],
   "source": [
    "print('labels count: ', len(all_files))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4313ab5a-f76e-497b-97ec-f4ac83e9e970",
   "metadata": {},
   "outputs": [],
   "source": [
    "keep_class_names = [  ] # auto scan if blank. 如果留空，会自动扫描类别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "64016188-8f87-4f86-adb4-f28283641731",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_names = dict(zip(keep_class_names, range(len(keep_class_names))))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a983416f-ac2b-4051-a8a2-c271815c5fb7",
   "metadata": {},
   "source": [
    "## convert. 转换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "260e49ec-7e07-4dc0-9b51-2ba48a9865a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█▋        | 53/311 [00:00<00:00, 463.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Auto scan classnames enabled.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 311/311 [00:00<00:00, 473.06it/s]\n"
     ]
    }
   ],
   "source": [
    "auto_scan = False\n",
    "class_index = 0\n",
    "if len(class_names) == 0:\n",
    "    auto_scan = True\n",
    "    print('Auto scan classnames enabled.')\n",
    "\n",
    "for file in tqdm(all_files):\n",
    "    file_path = os.path.join(labels_dir, file)\n",
    "    tree = ET.parse(file_path)\n",
    "    root = tree.getroot()\n",
    "    \n",
    "    img_size = root.find('size')\n",
    "    width = float(img_size.find('width').text)\n",
    "    height = float(img_size.find('height').text)\n",
    "    \n",
    "    objs = root.findall('object')\n",
    "    \n",
    "    with open(os.path.join(yolov5_all_labels, file[:-4] + '.txt'), 'w+') as f:\n",
    "        for obj in objs:\n",
    "            name = obj.find('name').text.strip()\n",
    "            if name not in class_names.keys():\n",
    "                if auto_scan:\n",
    "                    class_names[name] = class_index\n",
    "                    class_index = class_index + 1\n",
    "                else:\n",
    "                    continue\n",
    "            rbb = obj.find('robndbox')\n",
    "            if not rbb:\n",
    "                print('no robndbox in %s' % (file_path))\n",
    "            cx = float(rbb.find('cx').text)\n",
    "            cy = float(rbb.find('cy').text)\n",
    "            w = float(rbb.find('w').text)\n",
    "            h = float(rbb.find('h').text)\n",
    "            angle = float(rbb.find('angle').text)\n",
    "            \n",
    "            degree = round(angle / math.pi * 180)\n",
    "            if h > w:     # swap w,h if w is not longside. 宽不是长边时，交换宽高\n",
    "                w, h = h, w\n",
    "                if degree < 90:\n",
    "                    degree = degree + 90\n",
    "                else:\n",
    "                    degree = degree - 90\n",
    "            cv_degree = 180 - degree     # opencv angle format. opencv格式角度\n",
    "            if cv_degree == 180:\n",
    "                cv_degree = 0\n",
    "            assert cv_degree >= 0 and cv_degree < 180\n",
    "            \n",
    "            f.write('{} {} {} {} {} {}\\n'.format(class_names[name], cx/width, cy/height, w/width, h/width, cv_degree))\n",
    "    # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d5949e78-4239-4f95-bc8e-1e8487ea5b33",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "class_names:  {'0': 0}\n"
     ]
    }
   ],
   "source": [
    "print('class_names: ', class_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d685ac8c-d1d4-435e-b694-0c97960d7192",
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_keys = [i[0] for i in sorted(class_names.items(), key = lambda kv:(kv[1], kv[0]))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "84890397-d5fc-4683-9e5a-ab175a5cdcc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "names: [ \"0\" ]\n"
     ]
    }
   ],
   "source": [
    "print('names: [ \"{}\" ]'.format('\", \"'.join(sorted_keys)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3e122a8-6852-4fcb-957a-4389e923d972",
   "metadata": {},
   "source": [
    "## split datasets for yolov5. 将转换格式后的数据集按yolov5方式组织"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "252717a6-24a0-4c86-810e-4de3ac097181",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3b3a3d9c-eb81-403f-89c0-47fc2cdc17e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "imgages = os.path.join(workdir, 'images')\n",
    "labels = os.path.join(workdir, 'labels')\n",
    "train_img = os.path.join(imgages, 'train')\n",
    "val_img = os.path.join(imgages, 'val')\n",
    "train_label = os.path.join(labels, 'train')\n",
    "val_label = os.path.join(labels, 'val')\n",
    "for d in [imgages, labels, train_img, val_img, train_label, val_label]:\n",
    "    if not os.path.exists(d):\n",
    "        os.mkdir(d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "b1012ce8-371d-47cf-b11d-461cc061337e",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_txts = [i for i in os.listdir(yolov5_all_labels) if i[-4:] == '.txt']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "29f764d5-5670-478e-9415-302905f07456",
   "metadata": {},
   "outputs": [],
   "source": [
    "random.shuffle(all_txts) # shuffle data. 随机打乱顺序"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42ee61b4-756d-4bb0-8dd1-b4fd4b9051da",
   "metadata": {},
   "source": [
    "## split (train, test) rate. 随机划分数据集，比例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ec3a383f-8467-437b-b245-49e58b28e201",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_factor = 0.8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "875972bf-e3c6-453f-9145-7bd24c251f9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "inds = int(train_factor * len(all_txts))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "89543a08-3925-445b-8b08-2ee3e4585e31",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "all count:  311\n",
      "train count:  248\n",
      "test count:  63\n"
     ]
    }
   ],
   "source": [
    "print('all count: ', len(all_txts))\n",
    "print('train count: ', inds)\n",
    "print('test count: ', len(all_txts) - inds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "2d27b6d3-c596-4e0b-9988-14a9fcb6369c",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_txts = all_txts[:inds]\n",
    "val_txts = all_txts[inds:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67b393cf-cdf9-42b6-a0c9-ae00a0dadb04",
   "metadata": {},
   "source": [
    "## copy images and labels to train/val dirs, only run once. 只执行一次"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "b323cebe-67c9-4cd0-8047-0593dfae43e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "run = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d5c8deaa-637a-4928-b6da-c5d8e6991de2",
   "metadata": {},
   "outputs": [],
   "source": [
    "if run:\n",
    "    for txt in train_txts:\n",
    "        shutil.copyfile(os.path.join(yolov5_all_labels, txt), os.path.join(train_label, txt))\n",
    "        src = os.path.join(images_dir, txt[:-4] + '.jpg')\n",
    "        if os.path.exists(src):\n",
    "            shutil.copyfile(src, os.path.join(train_img, txt[:-4] + '.jpg'))\n",
    "        else:\n",
    "            shutil.copyfile(os.path.join(images_dir, txt[:-4] + '.png'), os.path.join(train_img, txt[:-4] + '.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f61603ac-9fbc-4315-9ef7-05545e2ed5f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "if run:\n",
    "    for txt in val_txts:\n",
    "        shutil.copyfile(os.path.join(yolov5_all_labels, txt), os.path.join(val_label, txt))\n",
    "        src = os.path.join(images_dir, txt[:-4] + '.jpg')\n",
    "        if os.path.exists(src):\n",
    "            shutil.copyfile(src, os.path.join(val_img, txt[:-4] + '.jpg'))\n",
    "        else:\n",
    "            shutil.copyfile(os.path.join(images_dir, txt[:-4] + '.png'), os.path.join(val_img, txt[:-4] + '.png'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25ef6a1d-e0c3-45b1-be53-1b2967c5d7b7",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
